# Simulate multiple loci coalescent process with heterochronous sampling 

# Assumptions and modifications
# - truncates individual trees
# - focuses on bottleneck and cyclic trajectories
# - generates multiple conditionally independent trees (data)
# - deposits multiple trees in a single folder
# - samples placed uniformly across time period


# Clean the workspace and console
# NOTE: these calls act on the whole R session, not just this script.
# Close every open user connection
closeAllConnections()
# Remove all objects listed by ls() in the current environment (ls() is called
# with its defaults, so names starting with a dot are not listed or removed)
rm(list=ls())
# Emit a form-feed character (octal 014); presumably meant to clear the
# console in RStudio - other front ends may just print it (TODO confirm)
cat("\014")  
# Shut down all open graphics devices
graphics.off()

# Packages for phylodyn
library("sp")
library("devtools")
library("INLA")
library("spam")
library("ape")
library("phylodyn")

# Set working directory to the folder that contains this script.
# `parent.frame(2)$ofile` is only available when the file is run via source();
# it is NULL otherwise, so fall back to the --file= argument used by Rscript
# and, failing that, to the current working directory.
this.file <- parent.frame(2)$ofile
if (is.null(this.file)) {
  fileArg <- grep("^--file=", commandArgs(trailingOnly = FALSE), value = TRUE)
  if (length(fileArg) > 0) {
    this.file <- sub("^--file=", "", fileArg[1])
  }
}
# Keep an absolute path: a relative one would refer to a different location
# once setwd() has run, breaking every later use of this.dir.
this.dir <- if (is.null(this.file)) getwd() else normalizePath(dirname(this.file))
setwd(this.dir)

# Write `val` as a comma-separated file without row or column names.
#   val      - object handed to write.table (vector, matrix, data frame)
#   name     - file name, appended directly to `pathname`
#   pathname - destination prefix; no separator is inserted, so it must
#              already end with one (e.g. "out/")
tableWrite <- function(val, name, pathname) {
  # Plain concatenation of prefix and file name
  destination <- paste0(c(pathname, name), collapse = "")
  write.table(val, file = destination, sep = ",",
              row.names = FALSE, col.names = FALSE)
}

# Piecewise-constant bottleneck trajectory.
#   t - numeric vector of times
# Returns a numeric vector of length(t): 20 strictly inside (15, 40) and
# 200 elsewhere (t <= 15 or t >= 40). Entries whose comparison is NA are
# left at the initial value 0.
bottle_traj <- function(t) {
  pop <- numeric(length(t))
  in_bottleneck <- t > 15 & t < 40
  pop[in_bottleneck] <- 20
  pop[!in_bottleneck] <- 200
  pop
}

# Boom-bust trajectory: exponential growth up to the changepoint `bust`,
# exponential decay after it, shifted upwards by `offset`.
#   t      - numeric vector of times
#   bust   - changepoint time at which the peak (scale + offset) is reached
#   scale  - height of the peak above the offset
#   offset - baseline added to every value
# Returns a numeric vector of length(t).
boom_traj <- function(t, bust = 20, scale = 1000, offset = 100) {
  growing <- t <= bust
  decaying <- t > bust
  pop <- numeric(length(t))
  pop[growing] <- offset + scale * exp(t[growing] - bust)
  pop[decaying] <- offset + scale * exp(bust - t[decaying])
  pop
}

# Define a logistic trajectory with larger N
# N is the logistic amplitude and N0 the floor added to every value; both are
# globals read by logis_traj at call time.
N <- 500
N0 <- 0.01 * N

# Periodic (period 12) logistic trajectory.
#   t      - numeric vector of times
#   offset - shift added to t before the phase is computed
#   a      - steepness of the logistic transitions
# For phase (t + offset) %% 12 in [0, 6] the curve rises through its midpoint
# at phase 3; for phase in (6, 12) it falls through its midpoint at phase 9.
# Returns a numeric vector with one value per shifted time.
logis_traj <- function(t, offset = 0, a = 2) {
  phase <- (t + offset) %% 12
  rising <- phase <= 6
  pop <- numeric(length(phase))
  pop[rising] <- N0 + N / (1 + exp((3 - phase[rising]) * a))
  pop[!rising] <- N0 + N / (1 + exp((phase[!rising] - 12 + 3) * a))
  pop
}

# Main code for heterochronous simulations ----------------------------------------------------------

# No. loci considered (independent trees)
numLoci <- 20

# Choose trajectory case (index into trajNames)
trajCase <- 5
trajNames <- c('cyclicLoci', 'bottleLoci', 'boomLoci', 'steepLoci', 'logisLoci')

# Fail early on an invalid case instead of continuing with a NULL trajectory
# and an NA folder name
if (length(trajCase) != 1 || !(trajCase %in% seq_along(trajNames))) {
  stop("trajCase must be a single integer between 1 and ", length(trajNames),
       call. = FALSE)
}
trajVal <- trajNames[trajCase]

# Choose trajectory type. Dispatch on the name so each label is tied to its
# function; with a numeric selector switch() picks purely by position and
# ignores argument names. Only the selected function is evaluated.
trajType <- switch(trajVal,
                   cyclicLoci = cyclic_traj,
                   bottleLoci = bottle_traj,
                   boomLoci = boom_traj,
                   steepLoci = steep_cyc_traj,
                   logisLoci = logis_traj
)
traj <- trajType

# Uniform sampling across time
all_samp_end <- 40 #often set to 60 for cyclic and bottle
nsamps <- 801
ndivs <- 20

# Period of truncation and extra initial samples
truncTime <- 85
tsamp0 <- 20
# Extra samples are only non-zero for the bottleneck case (trajCase 2)
tsamp1 <- tsamp2 <- if (trajCase == 2) 20 else 0

# Sample number and times: ndivs equally spaced sampling times, with an equal
# share of samples at each and the remainder added to the last time point
samps <- rep(floor(nsamps / ndivs), ndivs)
samps[ndivs] <- nsamps - sum(samps[-ndivs])
samp_times <- seq(0, all_samp_end, length.out = ndivs)

# Create folder for traj specific results, named <trajectory>_<nsamps - 1>
trajName <- paste0(trajVal, "_", nsamps - 1)
dir.create(file.path(this.dir, trajName))
# Prefix handed to tableWrite; it must end in "/" because tableWrite simply
# concatenates prefix and file name
pathf <- paste0(this.dir, "/", trajName, "/")

# Coalescent events and max time for each trajectory
nc <- rep(0, numLoci)
tmax <- rep(0, numLoci)

# coalsim's thinning method expects lower_bound to be a lower limit of traj.
# A fixed value of 10 is too large for trajectories that dip below it (e.g.
# logis_traj bottoms out near 6.2), so cap it by the trajectory minimum over
# the retained window [0, truncTime]. The 0.9 factor is a safety margin in
# case the evaluation grid misses the exact minimum.
trajMin <- min(traj(seq(0, truncTime, length.out = 10001)))
if (!is.finite(trajMin) || trajMin <= 0) {
  stop("Trajectory must be finite and positive on [0, truncTime]", call. = FALSE)
}
lowerBound <- min(10, 0.9 * trajMin)

# seq_len() so that numLoci = 0 runs no iterations (1:0 would run two)
for (i in seq_len(numLoci)) {

  # Simulate genealogy and get all times
  gene <- coalsim(samp_times = samp_times, n_sampled = samps, traj = traj,
                  lower_bound = lowerBound, method = "thin")

  # Truncate tree: keep only coalescent events no later than truncTime
  idtrunc <- gene$coal_times <= truncTime
  coal_times <- gene$coal_times[idtrunc]
  coalLin <- gene$lineages[idtrunc]

  # Last retained coalescent time and no. coalescent events
  tmax[i] <- max(coal_times)
  nc[i] <- length(coal_times)

  # Export trajectory specific data for Matlab
  tableWrite(coal_times, paste0('coaltimes', i, '.csv'), pathf)
  tableWrite(coalLin, paste0('coalLin', i, '.csv'), pathf)
}

# True population size evaluated on a fine grid up to the largest retained
# coalescent time across loci
t <- seq(0, max(tmax), length.out = 20000)
y <- traj(t)

# Everything exported for downstream use, keyed by file name
exports <- list(
  # Number of loci and coalescences
  'nc.csv' = nc,
  'numLoci.csv' = numLoci,
  # No. samples, truncation and TMRCA
  'ns.csv' = nsamps,
  'tmax.csv' = tmax,
  'truncTime.csv' = truncTime,
  # Sample scheme
  'samptimes.csv' = samp_times,
  'sampIntro.csv' = samps,
  # True population size
  'trajt.csv' = t,
  'trajy.csv' = y
)
for (fileName in names(exports)) {
  tableWrite(exports[[fileName]], fileName, pathf)
}
